{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b2d84944",
   "metadata": {},
   "source": [
    "# Using Custom DataBuilder in SecretFlow (Torch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17489bf0",
   "metadata": {},
   "source": [
    "The following code is for demonstration only. It is **NOT for production** due to system security concerns; please **DO NOT** use it directly in production."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3ec8e75",
   "metadata": {},
   "source": [
    "This tutorial demonstrates how to use the custom DataBuilder mode to load data and train models in the multi-party secure environment of SecretFlow.\n",
    "\n",
    "It uses an image classification task on the Flower dataset to illustrate how a custom DataBuilder supports federated learning in SecretFlow."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34f4426e",
   "metadata": {},
   "source": [
    "## Environment Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bb62f0dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9aca94a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-04-17 15:14:51,955\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    }
   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()\n",
    "sf.init(['alice', 'bob', 'charlie'], address=\"local\", log_to_driver=False)\n",
    "alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85476766",
   "metadata": {},
   "source": [
    "## Interface Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83350c9f",
   "metadata": {},
   "source": [
    "The `FLModel` in SecretFlow supports a custom DataBuilder for reading data, which lets users handle data input flexibly according to their specific requirements.\n",
    "\n",
    "Below, we provide an example to demonstrate how to use the custom DataBuilder for federated model training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d6d5890",
   "metadata": {},
   "source": [
    "Steps for using DataBuilder:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cb71476",
   "metadata": {},
   "source": [
    "1. Develop a DataBuilder function that constructs the PyTorch `DataLoader`, following ordinary single-machine logic. *Note: The `dataset_builder` function requires the `stage` parameter.*\n",
    "2. Wrap each party's DataBuilder function in a factory such as `create_dataset_builder`.\n",
    "3. Construct `data_builder_dict`, a dict that maps each `PYU` to its `dataset_builder`.\n",
    "4. Pass `data_builder_dict` as the `dataset_builder` argument of the `fit` function, and provide the input required by the builder in the `x` parameter position. (In this example, the input is the path to the image folder.)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a014150c",
   "metadata": {},
   "source": [
    "In FLModel, using DataBuilder requires predefining a DataBuilder dictionary. Each builder in it must return a `DataLoader` and `steps_per_epoch`, and the `steps_per_epoch` returned by each party must be the same.\n",
    "```python\n",
    "data_builder_dict = {\n",
    "    alice: create_alice_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ),  # the builder for alice must return (DataLoader, steps_per_epoch)\n",
    "    bob: create_bob_dataset_builder(\n",
    "        batch_size=32,\n",
    "    ),  # the builder for bob must return (DataLoader, steps_per_epoch)\n",
    "}\n",
    "```"
   ]
  },
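  {
   "cell_type": "markdown",
   "id": "c7e1a001",
   "metadata": {},
   "source": [
    "The consistency requirement on `steps_per_epoch` can be checked before training starts. The following is a minimal sketch rather than a SecretFlow API: `steps_per_epoch` and `check_consistent_steps` are hypothetical helper names, the sample counts are made-up values, and it assumes the last partial batch is kept, so the step count is `ceil(num_samples / batch_size)`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def steps_per_epoch(num_samples, batch_size):\n",
    "    # Number of batches needed to cover all samples, keeping the last partial batch.\n",
    "    return math.ceil(num_samples / batch_size)\n",
    "\n",
    "\n",
    "def check_consistent_steps(samples_per_party, batch_size):\n",
    "    # Map each party name to its step count and verify that all counts agree.\n",
    "    steps = {p: steps_per_epoch(n, batch_size) for p, n in samples_per_party.items()}\n",
    "    if len(set(steps.values())) > 1:\n",
    "        raise ValueError(f'steps_per_epoch differs across parties: {steps}')\n",
    "    return steps\n",
    "\n",
    "\n",
    "# Hypothetical sample counts: both parties need 3 steps with batch_size=32.\n",
    "print(check_consistent_steps({'alice': 80, 'bob': 96}, batch_size=32))\n",
    "```\n",
    "\n",
    "With these counts both parties run 3 steps per epoch, so the check passes; mismatched counts raise a `ValueError` instead."
   ]
  },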
  {
   "cell_type": "markdown",
   "id": "872d44a2",
   "metadata": {},
   "source": [
    "## Download Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3fdc6fb",
   "metadata": {},
   "source": [
    "Introduction to the Flower dataset: The Flower dataset is a collection of 4323 color images covering 5 types of flowers (daisies, dandelions, roses, sunflowers, and tulips). Each flower category comprises multiple images captured from various angles and under different lighting conditions. The resolution of each image is 320x240. This dataset is commonly used for training and testing image classification algorithms. The number of samples in each category is as follows: daisies (633), dandelions (898), roses (641), sunflowers (699), and tulips (852).\n",
    "  \n",
    "Download link: [http://download.tensorflow.org/example_images/flower_photos.tgz](http://download.tensorflow.org/example_images/flower_photos.tgz)\n",
    "\n",
    "<img alt=\"flower_dataset_demo.png\" src=\"https://www.secretflow.org.cn/static/flower_dataset_demo.553ea776.png\" width=\"600\">  \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65fd8419",
   "metadata": {},
   "source": [
    "### Download data and extract"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4ad06caa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz\n",
      "67588319/67588319 [==============================] - 1s 0us/step\n"
     ]
    }
   ],
   "source": [
    "# The TensorFlow interface is reused to download images , and the output is a folder, as shown in the following figure.\n",
    "import tempfile\n",
    "import tensorflow as tf\n",
    "\n",
    "_temp_dir = tempfile.mkdtemp()\n",
    "path_to_flower_dataset = tf.keras.utils.get_file(\n",
    "    \"flower_photos\",\n",
    "    \"https://secretflow-data.oss-accelerate.aliyuncs.com/datasets/tf_flowers/flower_photos.tgz\",\n",
    "    untar=True,\n",
    "    cache_dir=_temp_dir,\n",
    ")"
   ]
  },
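  {
   "cell_type": "markdown",
   "id": "c7e1a002",
   "metadata": {},
   "source": [
    "`torchvision.datasets.ImageFolder`, which this tutorial uses to read the images, expects one subfolder per class. A quick way to inspect the extracted folder is sketched below. `count_images_per_class` is a hypothetical helper written for this illustration, and the set of file extensions is an assumption.\n",
    "\n",
    "```python\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def count_images_per_class(root, extensions=('.jpg', '.jpeg', '.png')):\n",
    "    # Treat every subdirectory of root as one class and count its image files.\n",
    "    counts = {}\n",
    "    for class_dir in sorted(Path(root).iterdir()):\n",
    "        if class_dir.is_dir():\n",
    "            counts[class_dir.name] = sum(\n",
    "                1 for f in class_dir.iterdir() if f.suffix.lower() in extensions\n",
    "            )\n",
    "    return counts\n",
    "```\n",
    "\n",
    "For example, `count_images_per_class(path_to_flower_dataset)` returns a dict keyed by class folder name."
   ]
  },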
  {
   "cell_type": "markdown",
   "id": "1c60f3b6",
   "metadata": {},
   "source": [
    "## Construct a Custom DataBuilder"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90548868",
   "metadata": {},
   "source": [
    "### 1. Develop DataBuilder using a single-machine engine."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7690b723",
   "metadata": {},
   "source": [
    "When developing the `DataBuilder`, we can follow ordinary single-machine logic. The goal is to construct a PyTorch `DataLoader` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "59a766f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.sampler import SubsetRandomSampler\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# parameter\n",
    "batch_size = 32\n",
    "shuffle = True\n",
    "random_seed = 1234\n",
    "train_split = 0.8\n",
    "\n",
    "# Define dataset\n",
    "flower_transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize((180, 180)),\n",
    "        transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "flower_dataset = datasets.ImageFolder(\n",
    "    path_to_flower_dataset, transform=flower_transform\n",
    ")\n",
    "dataset_size = len(flower_dataset)\n",
    "# Define sampler\n",
    "\n",
    "indices = list(range(dataset_size))\n",
    "if shuffle:\n",
    "    np.random.seed(random_seed)\n",
    "    np.random.shuffle(indices)\n",
    "split = int(np.floor(train_split * dataset_size))\n",
    "train_indices, val_indices = indices[:split], indices[split:]\n",
    "train_sampler = SubsetRandomSampler(train_indices)\n",
    "valid_sampler = SubsetRandomSampler(val_indices)\n",
    "\n",
    "# Define databuilder\n",
    "train_loader = DataLoader(flower_dataset, batch_size=batch_size, sampler=train_sampler)\n",
    "valid_loader = DataLoader(flower_dataset, batch_size=batch_size, sampler=valid_sampler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dec7044c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x.shape = torch.Size([32, 3, 180, 180])\n",
      "y.shape = torch.Size([32])\n"
     ]
    }
   ],
   "source": [
    "x, y = next(iter(train_loader))\n",
    "print(f\"x.shape = {x.shape}\")\n",
    "print(f\"y.shape = {y.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85d6833a",
   "metadata": {},
   "source": [
    "### 2. Wrap the developed DataBuilder."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad896183",
   "metadata": {},
   "source": [
    "The DataBuilder we have developed is distributed to and executed on each party's machine at runtime. To make it easy to serialize, we wrap it in a factory function.\n",
    "\n",
    "Keep the following points in mind:\n",
    "\n",
    "- FLModel requires that the DataBuilder accept the `stage` parameter (for example, `stage=\"train\"`).\n",
    "- FLModel requires that the DataBuilder return two results: the dataset (here a `DataLoader`) and `steps_per_epoch`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9a502be1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataset_builder(\n",
    "    batch_size=32,\n",
    "    train_split=0.8,\n",
    "    shuffle=True,\n",
    "    random_seed=1234,\n",
    "):\n",
    "    def dataset_builder(x, stage=\"train\"):\n",
    "        \"\"\" \"\"\"\n",
    "        import math\n",
    "\n",
    "        import numpy as np\n",
    "        from torch.utils.data import DataLoader\n",
    "        from torch.utils.data.sampler import SubsetRandomSampler\n",
    "        from torchvision import datasets, transforms\n",
    "\n",
    "        # Define dataset\n",
    "        flower_transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.Resize((180, 180)),\n",
    "                transforms.ToTensor(),\n",
    "            ]\n",
    "        )\n",
    "        flower_dataset = datasets.ImageFolder(x, transform=flower_transform)\n",
    "        dataset_size = len(flower_dataset)\n",
    "        # Define sampler\n",
    "\n",
    "        indices = list(range(dataset_size))\n",
    "        if shuffle:\n",
    "            np.random.seed(random_seed)\n",
    "            np.random.shuffle(indices)\n",
    "        split = int(np.floor(train_split * dataset_size))\n",
    "        train_indices, val_indices = indices[:split], indices[split:]\n",
    "        train_sampler = SubsetRandomSampler(train_indices)\n",
    "        valid_sampler = SubsetRandomSampler(val_indices)\n",
    "\n",
    "        # Define databuilder\n",
    "        train_loader = DataLoader(\n",
    "            flower_dataset, batch_size=batch_size, sampler=train_sampler\n",
    "        )\n",
    "        valid_loader = DataLoader(\n",
    "            flower_dataset, batch_size=batch_size, sampler=valid_sampler\n",
    "        )\n",
    "\n",
    "        # Return\n",
    "        if stage == \"train\":\n",
    "            train_step_per_epoch = math.ceil(split / batch_size)\n",
    "\n",
    "            return train_loader, train_step_per_epoch\n",
    "        elif stage == \"eval\":\n",
    "            eval_step_per_epoch = math.ceil((dataset_size - split) / batch_size)\n",
    "            return valid_loader, eval_step_per_epoch\n",
    "\n",
    "    return dataset_builder"
   ]
  },
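  {
   "cell_type": "markdown",
   "id": "c7e1a003",
   "metadata": {},
   "source": [
    "The wrapping pattern is an ordinary closure: the outer function captures the configuration, and the inner function keeps the `(x, stage)` signature. The sketch below illustrates the pattern with plain Python lists standing in for `DataLoader` objects. `make_builder` is a hypothetical name, and raising an error for an unknown `stage` is a design choice of this sketch.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def make_builder(batch_size=32, train_split=0.8):\n",
    "    # The outer function captures the configuration values.\n",
    "    def builder(x, stage='train'):\n",
    "        # x is any sequence of samples; split it into train and eval parts.\n",
    "        split = int(len(x) * train_split)\n",
    "        if stage == 'train':\n",
    "            part = x[:split]\n",
    "        elif stage == 'eval':\n",
    "            part = x[split:]\n",
    "        else:\n",
    "            raise ValueError(f'unknown stage: {stage}')\n",
    "        # Group samples into batches and report how many batches there are.\n",
    "        batches = [part[i : i + batch_size] for i in range(0, len(part), batch_size)]\n",
    "        return batches, math.ceil(len(part) / batch_size)\n",
    "\n",
    "    return builder\n",
    "\n",
    "\n",
    "builder = make_builder(batch_size=4, train_split=0.8)\n",
    "train_batches, train_steps = builder(list(range(10)), stage='train')\n",
    "print(train_steps, len(train_batches))\n",
    "```\n",
    "\n",
    "The returned step count always equals the number of batches, which is the property the `steps_per_epoch` value is meant to convey."
   ]
  },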
  {
   "cell_type": "markdown",
   "id": "1f36a07f",
   "metadata": {},
   "source": [
    "### 3. Construct the `data_builder_dict`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1b45c27b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare dataset dict\n",
    "data_builder_dict = {\n",
    "    alice: create_dataset_builder(\n",
    "        batch_size=32,\n",
    "        train_split=0.8,\n",
    "        shuffle=False,\n",
    "        random_seed=1234,\n",
    "    ),\n",
    "    bob: create_dataset_builder(\n",
    "        batch_size=32,\n",
    "        train_split=0.8,\n",
    "        shuffle=False,\n",
    "        random_seed=1234,\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b1fb480",
   "metadata": {},
   "source": [
    "### 4. Use the `data_builder_dict` for federated training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f8fc010",
   "metadata": {},
   "source": [
    "Next, we define a FLModel with a Torch backend for training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56ec78af",
   "metadata": {},
   "source": [
    "#### Define the Model Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bdeb305f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "\n",
    "from secretflow.ml.nn.core.torch import BaseModule\n",
    "\n",
    "\n",
    "class ConvRGBNet(BaseModule):\n",
    "    def __init__(self, *args, **kwargs) -> None:\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.network = nn.Sequential(\n",
    "            nn.Conv2d(\n",
    "                in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1\n",
    "            ),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(2, 2),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(16 * 45 * 45, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 5),\n",
    "        )\n",
    "\n",
    "    def forward(self, xb):\n",
    "        return self.network(xb)"
   ]
  },
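  {
   "cell_type": "markdown",
   "id": "c7e1a004",
   "metadata": {},
   "source": [
    "The input size of the first `nn.Linear` layer, `16 * 45 * 45`, follows from the 180x180 images produced by the transform. The sketch below traces the spatial size through the layers with the standard output-size formula `floor((n + 2 * padding - kernel) / stride) + 1`. `out_size` is a hypothetical helper written for this illustration.\n",
    "\n",
    "```python\n",
    "def out_size(n, kernel, stride, padding=0):\n",
    "    # Output length along one spatial dimension for a convolution or pooling layer.\n",
    "    return (n + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "\n",
    "size = 180  # height and width after transforms.Resize((180, 180))\n",
    "size = out_size(size, kernel=3, stride=1, padding=1)  # first Conv2d keeps 180\n",
    "size = out_size(size, kernel=2, stride=2)  # first MaxPool2d halves to 90\n",
    "size = out_size(size, kernel=3, stride=1, padding=1)  # second Conv2d keeps 90\n",
    "size = out_size(size, kernel=2, stride=2)  # second MaxPool2d halves to 45\n",
    "flat_features = 16 * size * size  # 16 output channels\n",
    "print(size, flat_features)\n",
    "```\n",
    "\n",
    "If the resize target in the transform changes, the `nn.Linear` input size has to be recomputed in the same way."
   ]
  },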
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c1801328",
   "metadata": {},
   "outputs": [],
   "source": [
    "from secretflow.ml.nn import FLModel\n",
    "from secretflow.security.aggregation import SecureAggregator\n",
    "from torch import nn, optim\n",
    "from torchmetrics import Accuracy, Precision\n",
    "from secretflow.ml.nn.core.torch import metric_wrapper, optim_wrapper, TorchModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c1939910",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party alice.\n",
      "INFO:root:Create proxy actor <class 'secretflow.security.aggregation.secure_aggregator._Masker'> with party bob.\n",
      "INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party alice.\n",
      "INFO:root:Create proxy actor <class 'secretflow.ml.nn.fl.backend.torch.strategy.fed_avg_w.PYUFedAvgW'> with party bob.\n"
     ]
    }
   ],
   "source": [
    "device_list = [alice, bob]\n",
    "aggregator = SecureAggregator(charlie, [alice, bob])\n",
    "# prepare model\n",
    "num_classes = 5\n",
    "\n",
    "input_shape = (180, 180, 3)\n",
    "# torch model\n",
    "loss_fn = nn.CrossEntropyLoss\n",
    "optim_fn = optim_wrapper(optim.Adam, lr=1e-3)\n",
    "model_def = TorchModel(\n",
    "    model_fn=ConvRGBNet,\n",
    "    loss_fn=loss_fn,\n",
    "    optim_fn=optim_fn,\n",
    "    metrics=[\n",
    "        metric_wrapper(\n",
    "            Accuracy, task=\"multiclass\", num_classes=num_classes, average='micro'\n",
    "        ),\n",
    "        metric_wrapper(\n",
    "            Precision, task=\"multiclass\", num_classes=num_classes, average='micro'\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "\n",
    "fed_model = FLModel(\n",
    "    device_list=device_list,\n",
    "    model=model_def,\n",
    "    aggregator=aggregator,\n",
    "    backend=\"torch\",\n",
    "    strategy=\"fed_avg_w\",\n",
    "    random_seed=1234,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "650d095d",
   "metadata": {},
   "source": [
    "The input to our dataset builder is the path to the image dataset, so the input data must be a `Dict` that maps each party to its folder path.\n",
    "```python\n",
    "data = {\n",
    "    alice: folder_path_of_alice,\n",
    "    bob: folder_path_of_bob\n",
    "}\n",
    "```"
   ]
  },
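  {
   "cell_type": "markdown",
   "id": "c7e1a005",
   "metadata": {},
   "source": [
    "Because both the input data and the dataset builders are keyed by party, it can help to confirm that the two dicts cover the same parties before calling `fit`. The helper below is a hypothetical sketch, not a SecretFlow API, and plain strings stand in for the PYU devices.\n",
    "\n",
    "```python\n",
    "def check_party_alignment(data, builders):\n",
    "    # Every party with input data needs a builder, and vice versa.\n",
    "    missing_builder = set(data) - set(builders)\n",
    "    missing_data = set(builders) - set(data)\n",
    "    if missing_builder or missing_data:\n",
    "        raise ValueError(\n",
    "            f'parties without builder: {sorted(missing_builder)}; '\n",
    "            f'parties without data: {sorted(missing_data)}'\n",
    "        )\n",
    "    return True\n",
    "\n",
    "\n",
    "# Strings stand in for PYU devices; the paths and builder values are placeholders.\n",
    "example_data = {'alice': '/path/to/alice', 'bob': '/path/to/bob'}\n",
    "example_builders = {'alice': object(), 'bob': object()}\n",
    "print(check_party_alignment(example_data, example_builders))\n",
    "```"
   ]
  },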
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5f2cfe44",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:FL Train Params: {'self': <secretflow.ml.nn.fl.fl_model.FLModel object at 0x7ff4efc87af0>, 'x': {alice: '/tmp/tmp59nrtvl5/datasets/flower_photos', bob: '/tmp/tmp59nrtvl5/datasets/flower_photos'}, 'y': None, 'batch_size': 32, 'batch_sampling_rate': None, 'epochs': 5, 'verbose': 1, 'callbacks': None, 'validation_data': {alice: '/tmp/tmp59nrtvl5/datasets/flower_photos', bob: '/tmp/tmp59nrtvl5/datasets/flower_photos'}, 'shuffle': False, 'class_weight': None, 'sample_weight': None, 'validation_freq': 1, 'aggregate_freq': 2, 'label_decoder': None, 'max_batch_size': 20000, 'prefetch_buffer_size': None, 'sampler_method': 'batch', 'random_seed': 1234, 'dp_spent_step_freq': 1, 'audit_log_dir': None, 'dataset_builder': {alice: <function create_dataset_builder.<locals>.dataset_builder at 0x7ff600fb7ee0>, bob: <function create_dataset_builder.<locals>.dataset_builder at 0x7ff6007148b0>}}\n",
      "100%|██████████| 30/30 [00:32<00:00,  1.08s/it, epoch: 1/5 -  multiclassaccuracy:0.3760416805744171  multiclassprecision:0.3760416805744171  val_multiclassaccuracy:0.0  val_multiclassprecision:0.0 ]\n",
      "100%|██████████| 8/8 [00:10<00:00,  1.27s/it, epoch: 2/5 -  multiclassaccuracy:0.5078125  multiclassprecision:0.5078125  val_multiclassaccuracy:0.1618257313966751  val_multiclassprecision:0.1618257313966751 ]\n",
      "100%|██████████| 8/8 [00:10<00:00,  1.28s/it, epoch: 3/5 -  multiclassaccuracy:0.51171875  multiclassprecision:0.51171875  val_multiclassaccuracy:0.004149377811700106  val_multiclassprecision:0.004149377811700106 ]\n",
      "100%|██████████| 8/8 [00:10<00:00,  1.27s/it, epoch: 4/5 -  multiclassaccuracy:0.5390625  multiclassprecision:0.5390625  val_multiclassaccuracy:0.02074688859283924  val_multiclassprecision:0.02074688859283924 ]\n",
      "100%|██████████| 8/8 [00:10<00:00,  1.28s/it, epoch: 5/5 -  multiclassaccuracy:0.5703125  multiclassprecision:0.5703125  val_multiclassaccuracy:0.016597511246800423  val_multiclassprecision:0.016597511246800423 ]\n"
     ]
    }
   ],
   "source": [
    "data = {\n",
    "    alice: path_to_flower_dataset,\n",
    "    bob: path_to_flower_dataset,\n",
    "}\n",
    "history = fed_model.fit(\n",
    "    data,\n",
    "    None,\n",
    "    validation_data=data,\n",
    "    epochs=5,\n",
    "    batch_size=32,\n",
    "    aggregate_freq=2,\n",
    "    sampler_method=\"batch\",\n",
    "    random_seed=1234,\n",
    "    dp_spent_step_freq=1,\n",
    "    dataset_builder=data_builder_dict,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
